import random
from .RAFDB import RAFDBDataset
from .SFEW import SFEWDataset
from torch.utils.data import Dataset


class MixedDataset(Dataset):
    """Concatenation of the RAF-DB and SFEW datasets for a given phase.

    The model is weak on the 'Laughing' / 'Happiness' label, so this dataset
    uses the full RAF-DB split together with the matching SFEW split.

    Item order is shuffled once at construction time using the global
    ``random`` module state. It stays fixed after that.

    Args:
        phase: Split name, forwarded unchanged to both underlying datasets.
        arg: Config object exposing ``rafdb_base`` and ``sfew_base``
            (passed as ``base_directory`` to the respective datasets).
        transform: Optional transform forwarded to both datasets.
    """

    # Source tags stored in ``mapper`` to say which dataset an entry is from.
    _RAF = 0
    _SFEW = 1

    def __init__(self, phase, arg, transform=None):
        self.RAF = RAFDBDataset(phase=phase, base_directory=arg.rafdb_base, transform=transform)
        self.SFEW = SFEWDataset(phase=phase, base_directory=arg.sfew_base, transform=transform)

        # mapper[i] == (index inside the source dataset, source tag)
        self.mapper = [(i, self._RAF) for i in range(len(self.RAF))]
        self.mapper += [(i, self._SFEW) for i in range(len(self.SFEW))]
        random.shuffle(self.mapper)

    def __len__(self):
        # mapper holds exactly one entry per item of RAF plus SFEW.
        return len(self.mapper)

    def __getitem__(self, idx):
        """Return the item at global position ``idx`` of the shuffled order."""
        # Direct tuple lookup keeps each access O(1); unzipping the whole
        # mapper per call would make every fetch O(len(dataset)).
        idx_internal, dataset = self.mapper[idx]
        if dataset == self._RAF:
            return self.RAF[idx_internal]
        return self.SFEW[idx_internal]
